{ "cells": [ { "cell_type": "markdown", "id": "8lwheJQ8RxAw", "metadata": { "id": "8lwheJQ8RxAw" }, "source": [ "### **3. X-learner**\n", "Next, we introduce the X-learner. Building on the S-learner and T-learner, the X-learner uses information from the control group to derive a better estimator for the treatment group, and vice versa, which can make it more efficient than the two learners above, particularly when one group is much larger than the other.\n", "\n", "The X-learner algorithm can be summarized in the following steps:\n", "\n", "\n", "**Step 1:** Estimate $\\mu_0(s)$ and $\\mu_1(s)$ separately with any regression algorithm or supervised machine learning method (as in the T-learner);\n", "\n", "\n", "**Step 2:** Obtain the imputed treatment effects for individuals in the treatment group ($\\tilde{\\Delta}_i^1$) and in the control group ($\\tilde{\\Delta}_i^0$):\n", "\\begin{equation*}\n", "\\tilde{\\Delta}_i^1:=R_i^1-\\hat\\mu_0(S_i^1), \\quad \\tilde{\\Delta}_i^0:=\\hat\\mu_1(S_i^0)-R_i^0.\n", "\\end{equation*}\n", "\n", "**Step 3:** Regress the imputed treatment effects on the covariates to obtain $\\hat\\tau_1(s):=\\mathbb{E}[\\tilde{\\Delta}_i^1|S=s]$ and $\\hat\\tau_0(s):=\\mathbb{E}[\\tilde{\\Delta}_i^0|S=s]$;\n", "\n", "**Step 4:** The final HTE estimator is given by\n", "\\begin{equation*}\n", "\\hat{\\tau}_{\\text{X-learner}}(s)=g(s)\\hat\\tau_0(s)+(1-g(s))\\hat\\tau_1(s),\n", "\\end{equation*}\n", "\n", "where $g(s)$ is a weight function taking values in $[0,1]$. One possible choice is to use an estimate of the propensity score as $g(s)$."
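,
"\n",
"\n",
"A minimal sketch of the four steps above, assuming scikit-learn is available. The helper `x_learner`, the synthetic data, and the use of `LogisticRegression` to estimate $g(s)$ are illustrative assumptions, not the `causaldm` implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"\n",
"def x_learner(S, A, R, make_model=GradientBoostingRegressor):\n",
"    # Hypothetical helper: returns a function s -> estimated HTE at s.\n",
"    S0, R0 = S[A == 0], R[A == 0]  # control group\n",
"    S1, R1 = S[A == 1], R[A == 1]  # treatment group\n",
"\n",
"    # Step 1: outcome models mu_0 and mu_1, one per group (as in the T-learner)\n",
"    mu0 = make_model().fit(S0, R0)\n",
"    mu1 = make_model().fit(S1, R1)\n",
"\n",
"    # Step 2: imputed individual treatment effects\n",
"    delta1 = R1 - mu0.predict(S1)\n",
"    delta0 = mu1.predict(S0) - R0\n",
"\n",
"    # Step 3: regress the imputed effects on the covariates\n",
"    tau1 = make_model().fit(S1, delta1)\n",
"    tau0 = make_model().fit(S0, delta0)\n",
"\n",
"    # Step 4: weight g(s), here an estimated propensity score P(A=1 | S=s)\n",
"    g = LogisticRegression().fit(S, A)\n",
"\n",
"    def tau_hat(s):\n",
"        w = g.predict_proba(s)[:, 1]\n",
"        return w * tau0.predict(s) + (1 - w) * tau1.predict(s)\n",
"\n",
"    return tau_hat\n",
"\n",
"\n",
"# Synthetic data (an assumption for illustration): the true effect is 1 + S[:, 0]\n",
"rng = np.random.default_rng(0)\n",
"n = 2000\n",
"S = rng.normal(size=(n, 2))\n",
"A = rng.binomial(1, 0.3, size=n)\n",
"R = S[:, 1] + A * (1 + S[:, 0]) + rng.normal(scale=0.1, size=n)\n",
"\n",
"tau_hat = x_learner(S, A, R)\n",
"hte = tau_hat(S)\n",
"print(hte.shape, hte.mean())\n",
"```\n",
"\n",
"Any regressor with `fit` and `predict` methods can be passed as `make_model`, for example `LinearRegression`."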
] }, { "cell_type": "code", "execution_count": 2, "id": "eRpP5k9MBtzO", "metadata": { "id": "eRpP5k9MBtzO" }, "outputs": [], "source": [ "# Import the required packages\n", "import numpy as np\n", "import pandas as pd\n", "from matplotlib import pyplot as plt\n", "from sklearn.ensemble import GradientBoostingRegressor\n", "from sklearn.linear_model import LinearRegression\n", "from causaldm.learners.CEL.Single_Stage import _env_getdata_CEL" ] }, { "cell_type": "markdown", "id": "XUu695Qrf61-", "metadata": { "id": "XUu695Qrf61-" }, "source": [ "### MovieLens Data" ] }, { "cell_type": "code", "execution_count": 3, "id": "JhfJntzcVVy2", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "executionInfo": { "elapsed": 288, "status": "ok", "timestamp": 1676750101543, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "JhfJntzcVVy2", "outputId": "7fab8a7a-7cd9-445c-a005-9a6d1994a071" }, "outputs": [ { "data": { "text/html": [ "
\n",
"| | user_id | movie_id | rating | age | Drama | Sci-Fi | gender_M | occupation_academic/educator | occupation_college/grad student | occupation_executive/managerial | occupation_other | occupation_technician/engineer |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 0 | 48.0 | 1193.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |\n",
"| 1 | 48.0 | 919.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |\n",
"| 2 | 48.0 | 527.0 | 5.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |\n",
"| 3 | 48.0 | 1721.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |\n",
"| 4 | 48.0 | 150.0 | 4.0 | 25.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |\n",
"| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |\n",
"| 65637 | 5878.0 | 3300.0 | 2.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |\n",
"| 65638 | 5878.0 | 1391.0 | 1.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |\n",
"| 65639 | 5878.0 | 185.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |\n",
"| 65640 | 5878.0 | 2232.0 | 1.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |\n",
"| 65641 | 5878.0 | 426.0 | 3.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |\n",
"\n",
"65642 rows × 12 columns
\n", "